/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hiai_ir_aipp_compatible_adapter.h"

#include <algorithm>

#include "base/error_types.h"
#include "graph/op/image_defs.h"
#include "graph/op/array_defs.h"

#include "framework/infra/log/log.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/cgraph/graph_sorter.h"
#include "framework/graph/core/node/node_walker.h"
#include "framework/graph/core/node/node_visitor.h"
#include "framework/graph/core/node/node.h"

#include "model_builder/ir/aipp/converter/aipp_param_info_converter.h"
#include "model_builder/ir/aipp/compatible/hiai_ir_aipp_compatible.h"

namespace hiai {
/**
 * Adapter entry point for AIPP compatibility info generation.
 *
 * Delegates directly to HiAIIRAippCompatible::GenerateAippCompatibleInfo; no extra work is
 * done here.
 *
 * @param graph       graph handed through unchanged to the underlying implementation
 * @param customData  string reference handed through unchanged; presumably receives the
 *                    generated compatibility info (TODO confirm in HiAIIRAippCompatible)
 * @return whatever status HiAIIRAippCompatible::GenerateAippCompatibleInfo returns
 */
Status GenerateAippCompatibleInfoAdapter(ge::ComputeGraph& graph, std::string& customData)
{
    const Status generateRet = HiAIIRAippCompatible::GenerateAippCompatibleInfo(graph, customData);
    return generateRet;
}

/**
 * Appends every node reported by dataNode's NodeWalker::ListOutDataNodes to outNodes.
 *
 * @param dataNode  node whose data consumers are listed
 * @param outNodes  output; consumers are appended (existing contents are kept)
 * @return hiai::SUCCESS on success, hiai::FAILURE if the walk reports an error
 */
Status GetDataNodeOutNodes(const ge::Node& dataNode, std::vector<ge::Node*>& outNodes)
{
    auto appendConsumer = [&outNodes](ge::Node& consumer) {
        outNodes.push_back(&consumer);
        return hiai::SUCCESS;
    };

    const auto walkRet = dataNode.ROLE(NodeWalker).ListOutDataNodes(std::move(appendConsumer));
    if (walkRet == hiai::SUCCESS) {
        return hiai::SUCCESS;
    }

    FMK_LOGE("dataNode ListOutDataNodes failed");
    return hiai::FAILURE;
}

/**
 * Appends the nodes consuming output index 0 of aippNode to aippNodeOutNodes.
 *
 * The caller follows the first entry to step along an AIPP chain.
 *
 * @param aippNode          node whose output-0 consumers are listed
 * @param aippNodeOutNodes  output; consumers are appended (existing contents are kept)
 * @return hiai::SUCCESS on success, hiai::FAILURE if the walk reports an error
 */
Status GetNextAippNode(const ge::Node& aippNode, std::vector<ge::Node*>& aippNodeOutNodes)
{
    auto appendSuccessor = [&aippNodeOutNodes](ge::Node& successor) {
        aippNodeOutNodes.push_back(&successor);
        return hiai::SUCCESS;
    };

    const auto walkRet = aippNode.ROLE(NodeWalker).ListOutDataNodes(0, std::move(appendSuccessor));
    if (walkRet == hiai::SUCCESS) {
        return hiai::SUCCESS;
    }

    FMK_LOGE("aippNode ListOutDataNodes 0 failed");
    return hiai::FAILURE;
}

/**
 * Collects the ConfigData nodes attached to the AIPP chains that start at outNodes.
 *
 * Each entry of outNodes is treated as the head of a chain. While the current node's type is
 * listed in AIPP_FUNC_TYPE:
 *   - its data input 1 is inspected and recorded if it is a hiai::op::ConfigData node;
 *   - the walk then moves to the first consumer of the node's output 0.
 * The walk stops at the first non-AIPP node or when there is no consumer.
 *
 * A ConfigData node already present in configDataNodes is not appended again. The result is
 * appended to the graph input list by UpdateInputOrder, and a ConfigData node shared by
 * several AIPP nodes (or reached from several Data nodes) must not appear there twice.
 *
 * @param outNodes         chain heads (e.g. consumers of a Data node); only read
 * @param configDataNodes  output; newly found ConfigData nodes are appended in discovery order
 * @return hiai::SUCCESS on success; returns early through HIAI_EXPECT_NOT_NULL /
 *         HIAI_EXPECT_EXEC when an AIPP node has no input 1 or its consumers cannot be listed
 *         (assumed to propagate a failure status — see macro definitions)
 */
Status FindConfigDataNodes(std::vector<ge::Node*>& outNodes, std::vector<ge::Node*>& configDataNodes)
{
    for (const ge::Node* aippNode : outNodes) {
        while (aippNode != nullptr &&
            std::find(AIPP_FUNC_TYPE.begin(), AIPP_FUNC_TYPE.end(), aippNode->ROLE(NodeSpec).Type()) !=
            AIPP_FUNC_TYPE.end()) {
            // Input 1 of an AIPP node carries its configuration.
            ge::Node* configNode = aippNode->ROLE(NodeWalker).InDataNode(1);
            HIAI_EXPECT_NOT_NULL(configNode);

            const bool isConfigData = configNode->ROLE(NodeSpec).Type() == hiai::op::ConfigData::TYPE;
            const bool alreadyCollected =
                std::find(configDataNodes.cbegin(), configDataNodes.cend(), configNode) != configDataNodes.cend();
            if (isConfigData && !alreadyCollected) {
                configDataNodes.push_back(configNode);
            }

            std::vector<ge::Node*> aippNodeOutNodes;
            HIAI_EXPECT_EXEC(GetNextAippNode(*aippNode, aippNodeOutNodes));

            if (aippNodeOutNodes.empty()) {
                break;
            }

            // Follow only the first consumer of output 0 to continue along the chain.
            aippNode = aippNodeOutNodes[0];
        }
    }

    return hiai::SUCCESS;
}

/**
 * Finds the ConfigData nodes reachable through AIPP chains hanging off the given Data nodes.
 *
 * For every Data node, its data consumers are listed with GetDataNodeOutNodes and handed to
 * FindConfigDataNodes, which appends any ConfigData nodes it discovers.
 *
 * @param dataNodes        Data nodes to start from (dereferenced without a null check)
 * @param configDataNodes  output; ConfigData nodes are appended
 * @return hiai::SUCCESS, or returns early via HIAI_EXPECT_EXEC when a helper fails
 */
Status GetConfigDataNodes(std::vector<ge::Node*>& dataNodes, std::vector<ge::Node*>& configDataNodes)
{
    // Scratch buffer reused across iterations; cleared before each Data node.
    std::vector<ge::Node*> consumers;
    for (ge::Node* const dataNode : dataNodes) {
        consumers.clear();
        HIAI_EXPECT_EXEC(GetDataNodeOutNodes(*dataNode, consumers));
        HIAI_EXPECT_EXEC(FindConfigDataNodes(consumers, configDataNodes));
    }

    return hiai::SUCCESS;
}

/**
 * Collects every node of type hiai::op::Data in graph.
 *
 * Nodes are visited with GraphListWalker::WalkAllNodes and appended in the order that walk
 * reports them.
 *
 * @param graph      graph to scan
 * @param dataNodes  output; Data nodes are appended (existing contents are kept)
 * @return hiai::SUCCESS on success, hiai::FAILURE if the graph walk reports an error
 */
Status GetDataNodes(ge::ComputeGraph& graph, std::vector<ge::Node*>& dataNodes)
{
    auto getDataNodes = [&dataNodes](ge::Node& dataNode) {
        if (dataNode.ROLE(NodeSpec).Type() == hiai::op::Data::TYPE) {
            dataNodes.push_back(&dataNode);
        }
        return hiai::SUCCESS;
    };
    if (graph.ROLE(GraphListWalker).WalkAllNodes(std::move(getDataNodes)) != hiai::SUCCESS) {
        // Name the walker actually invoked so the log points at the right call.
        FMK_LOGE("graph WalkAllNodes failed");
        return hiai::FAILURE;
    }

    return hiai::SUCCESS;
}

/**
 * Re-declares the graph inputs as [Data nodes..., ConfigData nodes...] when AIPP ConfigData
 * nodes are present, then re-sorts the graph with GraphSorter::SortNodesDFS.
 *
 * If no ConfigData node is reachable from the Data nodes, the graph is left untouched.
 *
 * @param graph  graph whose inputs may be reset and whose nodes may be re-sorted
 * @return hiai::SUCCESS when there is nothing to do; hiai::FAILURE if SetInputs fails;
 *         otherwise the status returned by SortNodesDFS. Returns early via HIAI_EXPECT_EXEC
 *         when collecting the Data / ConfigData nodes fails.
 */
Status UpdateInputOrder(ge::ComputeGraph& graph)
{
    std::vector<ge::Node*> dataNodes;
    HIAI_EXPECT_EXEC(GetDataNodes(graph, dataNodes));

    std::vector<ge::Node*> configDataNodes;
    HIAI_EXPECT_EXEC(GetConfigDataNodes(dataNodes, configDataNodes));

    if (configDataNodes.empty()) {
        FMK_LOGI("have no configDataNodes");
        return hiai::SUCCESS;
    }

    // Data nodes first, followed by the ConfigData nodes.
    std::vector<ge::Node*> inputNodes;
    inputNodes.reserve(dataNodes.size() + configDataNodes.size());
    inputNodes.insert(inputNodes.end(), dataNodes.cbegin(), dataNodes.cend());
    inputNodes.insert(inputNodes.end(), configDataNodes.cbegin(), configDataNodes.cend());

    if (graph.ROLE(GraphModifier).SetInputs(inputNodes) != hiai::SUCCESS) {
        FMK_LOGE("graph SetInputs failed");
        return hiai::FAILURE;
    }
    FMK_LOGI("have configDataNodes");
    return graph.ROLE(GraphSorter).SortNodesDFS();
}
} // namespace hiai
